{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bidirection LSTM - IMDB sentiment classification\n",
    "\n",
    "see **https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "KERAS_MODEL_FILEPATH = '../../demos/data/imdb_bidirectional_lstm/imdb_bidirectional_lstm.h5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/home/leon/miniconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "np.random.seed(1337)  # for reproducibility\n",
    "\n",
    "from keras.preprocessing import sequence\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Embedding, LSTM, Input, Bidirectional\n",
    "from keras.datasets import imdb\n",
    "from keras.callbacks import EarlyStopping, ModelCheckpoint\n",
    "\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data...\n",
      "Downloading data from https://s3.amazonaws.com/text-datasets/imdb.npz\n",
      "17465344/17464789 [==============================] - 9s 1us/step\n",
      "25000 train sequences\n",
      "25000 test sequences\n",
      "Pad sequences (samples x time)\n",
      "X_train shape: (25000, 200)\n",
      "X_test shape: (25000, 200)\n"
     ]
    }
   ],
   "source": [
    "max_features = 20000\n",
    "maxlen = 200  # cut texts after this number of words (among top max_features most common words)\n",
    "\n",
    "print('Loading data...')\n",
    "(X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=max_features)\n",
    "print(len(X_train), 'train sequences')\n",
    "print(len(X_test), 'test sequences')\n",
    "\n",
    "print(\"Pad sequences (samples x time)\")\n",
    "X_train = sequence.pad_sequences(X_train, maxlen=maxlen)\n",
    "X_test = sequence.pad_sequences(X_test, maxlen=maxlen)\n",
    "print('X_train shape:', X_train.shape)\n",
    "print('X_test shape:', X_test.shape)\n",
    "y_train = np.array(y_train)\n",
    "y_test = np.array(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Sequential()\n",
    "model.add(Embedding(max_features, 64, input_length=maxlen))\n",
    "model.add(Bidirectional(LSTM(32)))\n",
    "model.add(Dropout(0.5))\n",
    "model.add(Dense(1, activation='sigmoid'))\n",
    "\n",
    "# try using different optimizers and different optimizer configs\n",
    "model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 25000 samples, validate on 25000 samples\n",
      "Epoch 1/10\n",
      "Epoch 00001: val_acc improved from -inf to 0.86916, saving model to ../../demos/data/imdb_bidirectional_lstm/imdb_bidirectional_lstm.h5\n",
      " - 77s - loss: 0.4602 - acc: 0.7830 - val_loss: 0.3180 - val_acc: 0.8692\n",
      "Epoch 2/10\n",
      "Epoch 00002: val_acc improved from 0.86916 to 0.87272, saving model to ../../demos/data/imdb_bidirectional_lstm/imdb_bidirectional_lstm.h5\n",
      " - 74s - loss: 0.2304 - acc: 0.9169 - val_loss: 0.3164 - val_acc: 0.8727\n",
      "Epoch 3/10\n",
      "Epoch 00003: val_acc did not improve\n",
      " - 74s - loss: 0.1490 - acc: 0.9520 - val_loss: 0.3412 - val_acc: 0.8637\n",
      "Epoch 4/10\n",
      "Epoch 00004: val_acc did not improve\n",
      " - 75s - loss: 0.1026 - acc: 0.9694 - val_loss: 0.4146 - val_acc: 0.8626\n",
      "Epoch 00004: early stopping\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f34a2714ef0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Model saving callback\n",
    "checkpointer = ModelCheckpoint(filepath=KERAS_MODEL_FILEPATH, monitor='val_acc', verbose=1, save_best_only=True)\n",
    "\n",
    "# Early stopping\n",
    "early_stopping = EarlyStopping(monitor='val_acc', verbose=1, patience=2)\n",
    "\n",
    "# train\n",
    "batch_size = 128\n",
    "epochs = 10\n",
    "model.fit(X_train, y_train, \n",
    "          validation_data=[X_test, y_test],\n",
    "          batch_size=batch_size, epochs=epochs, verbose=2,\n",
    "          callbacks=[checkpointer, early_stopping])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "**sample data**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://s3.amazonaws.com/text-datasets/imdb_word_index.json\n",
      "1646592/1641221 [==============================] - 1s 0us/step\n"
     ]
    }
   ],
   "source": [
    "word_index = imdb.get_word_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "word_dict = {idx: word for word, idx in word_index.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"i'll keep it short and brief the people who wrote the story lines for this show are genius the actors are just perfect for the roles they play - character is legendary and they have so much chemistry on screen which makes it what it is a very successful comedy br br when i saw first saw the new episodes which is probably going back just over 6 7 months i wondered what had happened to paul i was gutted to find out that he had died when i - google he was so funny and played his character to perfection an over protective dad who likes to keep his daughters out of the limelight and away from boys br br the comedy i think has gone from strength to strength even without paul in it br br plus i think most people would enjoy this watching it\""
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample = []\n",
    "for idx in X_train[0]:\n",
    "    if idx >= 3:\n",
    "        sample.append(word_dict[idx-3])\n",
    "    elif idx == 2:\n",
    "        sample.append('-')\n",
    "' '.join(sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../../demos/data/imdb_bidirectional_lstm/imdb_dataset_word_index_top20000.json', 'w') as f:\n",
    "    f.write(json.dumps({word: idx for word, idx in word_index.items() if idx < max_features}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../../demos/data/imdb_bidirectional_lstm/imdb_dataset_word_dict_top20000.json', 'w') as f:\n",
    "    f.write(json.dumps({idx: word for word, idx in word_index.items() if idx < max_features}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_test_data = []\n",
    "for i in np.random.choice(range(X_test.shape[0]), size=1000, replace=False):\n",
    "    sample_test_data.append({'values': X_test[i].tolist(), 'label': y_test[i].tolist()})\n",
    "    \n",
    "with open('../../demos/data/imdb_bidirectional_lstm/imdb_dataset_test.json', 'w') as f:\n",
    "    f.write(json.dumps(sample_test_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
